In [1]:
import torch
import torch.nn as nn
# from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader
import torchvision.models as models
import torch.backends.cudnn as cudnn
import torchvision
import torch.autograd as autograd
from PIL import Image
import imp
import os
import sys
import math
import time
import random
import shutil
# import cv2
import scipy.misc
from glob import glob
import sklearn
import logging

from time import time
from tqdm import tqdm
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')

%matplotlib inline
In [2]:
if not torch.cuda.is_available():
    print("SORRY: No CUDA device")
In [3]:
imageSize = 64
batchSize = 64
In [19]:
nz = 100
ngf = 64
ndf = 64
nc = 3

nd = 2
cuda = True

Load data

In [5]:
PATH = 'celeba/'

data = dset.ImageFolder(PATH,
    transforms.Compose([
        transforms.Scale(imageSize),
        transforms.CenterCrop(imageSize),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
)

dataloader = DataLoader(data, batch_size=batchSize, shuffle=True)

Custom weights initialization called on netG and netD

In [6]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
In [61]:
class _netG(nn.Module):
    def __init__(self):
        super(_netG, self).__init__()
        self.ngpu = 1
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(     nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2,     ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(    ngf,      nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output
    
netG = _netG()
netG.apply(weights_init)
netG
Out[61]:
_netG (
  (main): Sequential (
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU (inplace)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (5): ReLU (inplace)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (8): ReLU (inplace)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (11): ReLU (inplace)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh ()
  )
)
In [62]:
class _netD(nn.Module):
    def __init__(self):
        super(_netD, self).__init__()
        self.ngpu = 1        
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        
        return output
        return output.view(-1, 1)

netDs = [_netD() for _ in range(nd)]
[netD.apply(weights_init) for netD in netDs]
Out[62]:
[_netD (
   (main): Sequential (
     (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
     (1): LeakyReLU (0.2, inplace)
     (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
     (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
     (4): LeakyReLU (0.2, inplace)
     (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
     (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
     (7): LeakyReLU (0.2, inplace)
     (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
     (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
     (10): LeakyReLU (0.2, inplace)
     (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
     (12): Sigmoid ()
   )
 ), _netD (
   (main): Sequential (
     (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
     (1): LeakyReLU (0.2, inplace)
     (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
     (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
     (4): LeakyReLU (0.2, inplace)
     (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
     (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
     (7): LeakyReLU (0.2, inplace)
     (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
     (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
     (10): LeakyReLU (0.2, inplace)
     (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
     (12): Sigmoid ()
   )
 )]

Setup input and output tensors

In [63]:
criterion = nn.BCELoss()

input = torch.FloatTensor(batchSize, 3, imageSize, imageSize)
noise = torch.FloatTensor(batchSize, nz, 1, 1)
fixed_noise = torch.FloatTensor(batchSize, nz, 1, 1).normal_(0, 1)
label = torch.FloatTensor(batchSize)
real_label = 1
fake_label = 0

if cuda:
    [netD.cuda() for netD in netDs]
    netG.cuda()
    criterion.cuda()
    input, label = input.cuda(), label.cuda()
    noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

input = Variable(input)
label = Variable(label)
noise = Variable(noise)
fixed_noise = Variable(fixed_noise)

Setup optimizer

In [55]:
lr = 0.0002
beta1 = 0.5
In [64]:
optimizerDs = [optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999)) for netD in netDs]
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

Train

In [57]:
niter = 1
In [65]:
tick = time()
losses = []
for epoch in range(niter):
    for i, data in enumerate(dataloader):
        
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        
        # get next batch of data
        real_cpu, _ = data
        batch_size = real_cpu.size(0)
        
        # real input
        input.data.resize_(real_cpu.size()).copy_(real_cpu)
        
        # fake input
        noise.data.resize_(batch_size, nz, 1, 1)
        noise.data.normal_(0, 1)
        fake = netG(noise)
        
        # for each D
        for netD, optimizerD in zip(netDs, optimizerDs):
            netD.zero_grad()
            
            # train with real
            output = netD(input)
            label.data.resize_(batch_size).fill_(real_label)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.data.mean()

            # train with fake    
            output = netD(fake.detach())
            label.data.fill_(fake_label)
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.data.mean()
            
            errD = errD_real + errD_fake
            optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        
        # randomly select a D
        # whichD = np.random.randint(0, nd)
        # netD = netDs[whichD]
        
        netG.zero_grad()
        
        output = Variable(torch.zeros(batch_size, 1, 1).cuda())
        for netD in netDs:
             output += netD(fake)
        output /= nd
        
        label.data.fill_(real_label)  # fake labels are real for generator cost
        
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.data.mean()
        optimizerG.step()
        
        losses.append((errD.data[0], errG.data[0]))
        
        if i%100 == 0:
            print('[{}/{}][{}/{}][{:.2f}] Loss_D: {:.2f} Loss_G: {:.2f} D(x): {:.2f} D(G(z)): {:.2f} / {:.2f}'.\
            format(epoch, niter, i, len(dataloader), time() - tick, errD.data[0], errG.data[0], D_x, D_G_z1, D_G_z2))
        
        if i%100 == 0:
            fake = netG(fixed_noise).data
            show(vutils.make_grid(fake[:25], normalize=True, nrow=5).cpu())            
            #show(vutils.make_grid(fake, normalize=True).cpu())
[0/1][0/3166][0.62] Loss_D: 2.39 Loss_G: 3.24 D(x): 0.25 D(G(z)): 0.46 / 0.05
[0/1][100/3166][53.95] Loss_D: 0.01 Loss_G: 27.63 D(x): 0.99 D(G(z)): 0.00 / 0.00
[0/1][200/3166][107.22] Loss_D: 0.09 Loss_G: 6.11 D(x): 0.95 D(G(z)): 0.01 / 0.01
[0/1][300/3166][160.69] Loss_D: 0.24 Loss_G: 4.50 D(x): 0.86 D(G(z)): 0.04 / 0.02
[0/1][400/3166][214.26] Loss_D: 0.14 Loss_G: 5.73 D(x): 0.95 D(G(z)): 0.05 / 0.01
[0/1][500/3166][267.60] Loss_D: 0.42 Loss_G: 4.93 D(x): 0.80 D(G(z)): 0.07 / 0.01
[0/1][600/3166][321.14] Loss_D: 1.07 Loss_G: 5.40 D(x): 0.94 D(G(z)): 0.52 / 0.01
[0/1][700/3166][374.77] Loss_D: 0.24 Loss_G: 5.21 D(x): 0.92 D(G(z)): 0.11 / 0.01
[0/1][800/3166][428.32] Loss_D: 0.43 Loss_G: 5.11 D(x): 0.90 D(G(z)): 0.19 / 0.01
[0/1][900/3166][481.97] Loss_D: 0.58 Loss_G: 3.83 D(x): 0.78 D(G(z)): 0.16 / 0.04
[0/1][1000/3166][535.55] Loss_D: 0.42 Loss_G: 5.17 D(x): 0.92 D(G(z)): 0.25 / 0.01
[0/1][1100/3166][589.33] Loss_D: 1.79 Loss_G: 3.76 D(x): 0.30 D(G(z)): 0.00 / 0.04
[0/1][1200/3166][642.77] Loss_D: 0.42 Loss_G: 5.64 D(x): 0.80 D(G(z)): 0.09 / 0.01
[0/1][1300/3166][696.46] Loss_D: 0.55 Loss_G: 4.80 D(x): 0.79 D(G(z)): 0.20 / 0.02
[0/1][1400/3166][750.00] Loss_D: 1.31 Loss_G: 3.55 D(x): 0.88 D(G(z)): 0.60 / 0.06
[0/1][1500/3166][803.75] Loss_D: 0.70 Loss_G: 3.93 D(x): 0.79 D(G(z)): 0.24 / 0.03
[0/1][1600/3166][857.04] Loss_D: 0.56 Loss_G: 2.82 D(x): 0.74 D(G(z)): 0.08 / 0.10
[0/1][1700/3166][910.51] Loss_D: 0.40 Loss_G: 3.31 D(x): 0.77 D(G(z)): 0.06 / 0.06
[0/1][1800/3166][964.10] Loss_D: 0.59 Loss_G: 5.41 D(x): 0.92 D(G(z)): 0.34 / 0.01
[0/1][1900/3166][1017.65] Loss_D: 0.27 Loss_G: 4.43 D(x): 0.96 D(G(z)): 0.18 / 0.02
[0/1][2000/3166][1070.92] Loss_D: 0.36 Loss_G: 4.34 D(x): 0.91 D(G(z)): 0.16 / 0.03
[0/1][2100/3166][1124.63] Loss_D: 0.58 Loss_G: 3.78 D(x): 0.91 D(G(z)): 0.34 / 0.04
[0/1][2200/3166][1177.60] Loss_D: 0.54 Loss_G: 3.66 D(x): 0.68 D(G(z)): 0.03 / 0.05
[0/1][2300/3166][1231.01] Loss_D: 0.38 Loss_G: 2.98 D(x): 0.79 D(G(z)): 0.05 / 0.08
[0/1][2400/3166][1283.96] Loss_D: 0.76 Loss_G: 2.35 D(x): 0.64 D(G(z)): 0.11 / 0.13
[0/1][2500/3166][1336.96] Loss_D: 1.29 Loss_G: 2.05 D(x): 0.39 D(G(z)): 0.01 / 0.18
[0/1][2600/3166][1389.97] Loss_D: 0.34 Loss_G: 4.67 D(x): 0.88 D(G(z)): 0.14 / 0.02
[0/1][2700/3166][1442.95] Loss_D: 0.38 Loss_G: 3.29 D(x): 0.89 D(G(z)): 0.21 / 0.06
[0/1][2800/3166][1496.14] Loss_D: 0.41 Loss_G: 3.81 D(x): 0.89 D(G(z)): 0.22 / 0.04
[0/1][2900/3166][1549.16] Loss_D: 0.72 Loss_G: 4.77 D(x): 0.90 D(G(z)): 0.41 / 0.01
[0/1][3000/3166][1602.16] Loss_D: 0.34 Loss_G: 3.75 D(x): 0.88 D(G(z)): 0.16 / 0.04
[0/1][3100/3166][1655.08] Loss_D: 0.29 Loss_G: 4.25 D(x): 0.81 D(G(z)): 0.03 / 0.03
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-65-9613aba9a687> in <module>()
     45         output = Variable(torch.zeros(batchSize, 1, 1).cuda())
     46         for netD in netDs:
---> 47              output += netD(fake)
     48         output /= nd
     49 

/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/autograd/variable.py in __iadd__(self, other)
    747 
    748     def __iadd__(self, other):
--> 749         return self.add_(other)
    750 
    751     def __sub__(self, other):

/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/autograd/variable.py in add_(self, other)
    284 
    285     def add_(self, other):
--> 286         return self._add(other, True)
    287 
    288     def _sub(self, other, inplace):

/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/autograd/variable.py in _add(self, other, inplace)
    275     def _add(self, other, inplace):
    276         if isinstance(other, Variable):
--> 277             return Add(inplace)(self, other)
    278         else:
    279             assert not torch.is_tensor(other)

/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/autograd/_functions/basic_ops.py in forward(self, a, b)
     16         if self.inplace:
     17             self.mark_dirty(a)
---> 18             return a.add_(b)
     19         else:
     20             return a.add(b)

RuntimeError: sizes do not match at /py/conda-bld/pytorch_1493674854206/work/torch/lib/THC/generated/../generic/THCTensorMathPointwise.cu:216

Show

In [59]:
def show(img, fs=(6,6)):
    plt.figure(figsize = fs)
    plt.imshow(np.transpose(img.numpy(), (1,2,0)))
    plt.show()
In [66]:
fake = netG(fixed_noise).data
In [67]:
show(vutils.make_grid(fake[:25], normalize=True, nrow=5).cpu())
In [ ]: